import json,copy
import re
from collections import deque

# 事件列表形如（按照一定顺序重新编号）：
# Events List: [
# {
#    "Event ID": "Event 1",
#    "Diegesis": "Slow-Pace Narration",
#    "Description": "Character A receives a mission from Character B.",
#    "Background": "Background information of Event 1.",
#    "Detailed": "Detailed explanation of Event 1.",
#    "End Suspense": "The suspense left by Event 1.",
#    "Associated Characters": ["Character A", "Character B"]
# },
# ]
# 关系列表形如：
# Relationships List:[
# {
#    "Events": Event 1 -> Event 2
#    "Strength": [The strength of the causal relationship, e.g. HIGH, MEDIUM, LOW ]
#    "Details": [Here is the detailed relationship of this tow events.]
# },
# ]

# CHAPTER方法：
# EVENTS LIST: [] #这里放事件列表（事件按照章节顺序重新编号）
# CHAPTERS STORYLINE: #这里按照节放边
# - Chapter 1: [] #第一节的关系列表（第一节的所有前驱边）
# - Chapter 2: [] #第二节的关系列表（第二节的所有前驱边）

# Topo方法:
# EVENTS LIST: [] #这里放事件列表（事件按照Topo序重新编号）
# STAGES STORYLINE: #这里按照层级放边
# - Stage 1: [] #第一层的关系列表（第一层的所有前驱边）
# - Stage 2: [] #第二层的关系列表（第二层的所有前驱边）


# DFS方法:
# EVENTS LIST: [] #这里放事件列表（事件按照dfs序重新编号）
# PLOTLINES STORYLINE: #这里按照故事线放边
# - Plotline 1: [] #第一条线的关系列表（第一条线的所有前驱边）
# - Plotline 2: [] #第二条线的关系列表（第二条线的所有前驱边）

# Module-level renumbering state: advanced by creat_new_ID() and reset at the
# start of every build_diagram() call.
event_index = 0  # last sequential number handed out as "Event N"
event_ID_dict = {}  # original 'Event ID' -> new "Event N" ID (latest assignment wins)

def creat_new_ID(event):
    """
    Allocate the next sequential "Event N" ID for `event`.

    Advances the module-level counter `event_index` and records the mapping
    event['Event ID'] -> new ID in `event_ID_dict`. If the original ID was already
    mapped, a warning is printed and the mapping is overwritten with the new ID.

    Returns:
        The newly allocated ID string.
    """
    global event_index, event_ID_dict
    event_index += 1
    original = event['Event ID']
    allocated = "Event " + str(event_index)
    already_mapped = original in event_ID_dict
    if already_mapped:
        print(f"Warning: Event ID {original} already exists.")
    event_ID_dict[original] = allocated  # remember old -> new for edge renaming
    return allocated

def creat_new_Event(event):
    """
    Return a deep copy of `event` carrying a freshly allocated sequential ID.

    The input dict is not modified; only the copy's 'Event ID' is replaced.
    """
    allocated_ID = creat_new_ID(event)  # also records old -> new in event_ID_dict
    clone = copy.deepcopy(event)
    clone['Event ID'] = allocated_ID
    return clone

def creat_new_Event_by_ID(Act_ID, plot_lines):
    """
    Find the event whose 'Event ID' equals `Act_ID` and return a renumbered copy.

    IDs are expected to look like "Act <a>-<n>" (an "Event " prefix is stripped too),
    where <a> is the 1-based position of the act in `plot_lines`. That act is searched
    first. If the ID does not parse, points outside `plot_lines`, or the event is not
    in that act, every act is searched in order.

    Returns:
        The renumbered event dict, or None (after printing a warning) when no event
        with that ID exists.
    """
    candidate_acts = []
    act_part, separator, _ = str(Act_ID).replace("Act ", "").replace("Event ", "").partition('-')
    if separator:
        try:
            act_index = int(act_part) - 1
        except ValueError:
            act_index = -1  # unparsable act number: rely on the full scan below
        # Guard against negative indices silently wrapping to the last act.
        if 0 <= act_index < len(plot_lines):
            candidate_acts.append(plot_lines[act_index])
    # Fallback: the event may be filed under a different act than its prefix says.
    candidate_acts.extend(plot_lines)

    for act in candidate_acts:
        for event in act.get('Plot Chains List', []):
            if event['Event ID'] == Act_ID:
                return creat_new_Event(event)
    print(f"    [Warning] creat_new_Event_by_ID(): Event {Act_ID} not found in plot lines.")
    return None

def topo_traverse(plot_lines, plots_edges):
    """
    Group events into stages with a FIFO (Kahn) topological sort and renumber them.

    Events with no incoming edge form stage 1. Every other event is placed one stage
    after the last of its predecessors to be processed. Because the FIFO queue handles
    events in non-decreasing depth, this equals 1 + the longest-path depth of its
    predecessors. Events are renumbered in the order they leave the queue.

    Args:
        plot_lines: list of acts, each a dict whose 'Plot Chains List' holds event dicts
            with an 'Event ID'.
        plots_edges: edges indexable as edge[0] = source ID, edge[1] = target ID.

    Returns:
        A list of stages, each a list of renumbered event dicts. An entry is None where
        creat_new_Event_by_ID could not resolve an ID.
    """
    events_list = []
    stage_events = []
    queue = deque()

    # In-degree counts plus an adjacency list built once. Each dequeued event then
    # touches only its own outgoing edges, instead of rescanning every edge (O(V*E)).
    # Successors keep the input edge order, so the visiting order is unchanged.
    degree, depth = {}, {}
    successors = {}
    for edge in plots_edges:
        source_ID, target_ID = edge[0], edge[1]
        degree[target_ID] = degree.get(target_ID, 0) + 1
        successors.setdefault(source_ID, []).append(target_ID)

    all_IDs = []
    for act in plot_lines:
        for event in act.get('Plot Chains List', []):
            old_ID = event['Event ID']
            all_IDs.append(old_ID)
            if degree.get(old_ID, 0) == 0:
                queue.append(old_ID)  # sources seed the queue
                depth[old_ID] = 1  # first stage

    now_depth = 1
    while queue:
        event_ID = queue.popleft()
        event_new = creat_new_Event_by_ID(event_ID, plot_lines)  # also records the ID mapping
        if depth[event_ID] == now_depth:
            stage_events.append(event_new)
        else:
            # Depth only grows along the FIFO queue, so a change closes the current stage.
            events_list.append(stage_events)
            stage_events = [event_new]
            now_depth = depth[event_ID]

        for target_ID in successors.get(event_ID, []):
            degree[target_ID] -= 1
            if degree[target_ID] == 0:
                queue.append(target_ID)
                depth[target_ID] = depth[event_ID] + 1

    events_list.append(stage_events)  # flush the final stage

    # Events never queued sit on a cycle or depend on an event missing from plot_lines;
    # they get no new ID, so surface them instead of dropping them silently.
    unreached = [ID for ID in all_IDs if ID not in depth]
    if unreached:
        print(f"    [Warning] topo_traverse(): {len(unreached)} event(s) never reached "
              f"(cycle or unknown predecessor): {unreached}")

    return events_list

def dfs_traverse(plot_lines, plots_edges):
    """
    Split the plot graph into plotlines with a depth-first traversal, renumbering in DFS order.

    Roots are tried in the order events appear in `plot_lines`. A plotline is a path in
    the DFS tree with these properties:
      - It starts at a root, or at a child that is not the first unvisited child of its parent.
      - It keeps following the first not-yet-visited successor.
      - It stops at an event with no unvisited successors.
    Every visited event therefore appears in exactly one plotline.

    Args:
        plot_lines: list of acts, each a dict whose 'Plot Chains List' holds event dicts
            with an 'Event ID'.
        plots_edges: edges indexable as edge[0] = source ID, edge[1] = target ID.

    Returns:
        A list of plotlines, each a list of renumbered event dicts. An entry is None
        where creat_new_Event_by_ID could not resolve an ID.

    NOTE: the traversal is recursive, so a single chain longer than Python's recursion
    limit (1000 by default) would raise RecursionError.
    """
    events_list = []
    visited = set()

    # Adjacency list built once, with edges kept in input order, instead of rescanning
    # every edge for every visited event.
    successors = {}
    for edge in plots_edges:
        successors.setdefault(edge[0], []).append(edge[1])

    def dfs(event_ID, plotline_stage):
        visited.add(event_ID)
        plotline_stage.append(creat_new_Event_by_ID(event_ID, plot_lines))  # records the ID mapping
        has_child = False
        for child_ID in successors.get(event_ID, []):
            if child_ID in visited:  # may already be reached through an earlier child
                continue
            has_child = True
            dfs(child_ID, plotline_stage)
            # Only the first unvisited child extends the current plotline;
            # every later child starts a plotline of its own.
            plotline_stage = []
        if not has_child:  # end of a path: the plotline is complete
            events_list.append(plotline_stage)

    for act in plot_lines:
        for event in act.get('Plot Chains List', []):
            old_ID = event['Event ID']
            if old_ID not in visited:
                dfs(old_ID, [])

    return events_list

# Plural heading for the "<HEADING> STORYLINE:" line, keyed by build_diagram() mode.
MODE = dict(
    chapter="CHAPTERS",
    Topo="STAGES",
    DFS="PLOTLINES",
)
# Singular label for each numbered group ("- <Label> N: ["), same keys as MODE.
MODE_ = dict(
    chapter="Chapter",
    Topo="Stage",
    DFS="Plotline",
)

# # Helper function to generate event names based on mode
# def generate_event_name(event, index, mode, stage=1):
#     if mode == "Topo":
#         return f"Stage {stage}-{event.split('-')[-1]}"
#     elif mode == "DFS":
#         return f"Plotline {index+1}-{event.split('-')[-1]}"
#     else:#elif mode == "chapter":
#         return f"Chapter {stage}-{event.split('-')[-1]}"

def _group_events_by_chapter(plot_lines):
    """Renumber every event act by act, returning one group per act in the original order."""
    return [
        [creat_new_Event(event) for event in act.get('Plot Chains List', [])]
        for act in plot_lines
    ]


def _rename_endpoint_IDs(text, replacements):
    """
    Rewrite old event IDs mentioned in free text to their new IDs.

    All IDs are replaced in a single regex pass, so an ID that was just substituted is
    never rewritten again. A digit look-ahead stops "Act 1-1" from matching inside
    "Act 1-12". Non-string text is returned unchanged.
    """
    replacements = {str(old): str(new) for old, new in replacements.items() if old}
    if not isinstance(text, str) or not replacements:
        return text
    # Longest IDs first so the most specific alternative is tried first.
    alternatives = "|".join(re.escape(old) for old in sorted(replacements, key=len, reverse=True))
    return re.sub(rf"(?:{alternatives})(?!\d)", lambda match: replacements[match.group(0)], text)


def _remap_edges(events_list, plots_edges, ID_map):
    """
    Renumber each edge and bucket it under the group that contains its target event.

    Returns a list parallel to `events_list`. Entry i holds deep-copied edges (as lists)
    whose endpoints use the new IDs, in input order. Edges touching an event that was
    never renumbered are reported and skipped.
    """
    group_of = {}
    for index, group in enumerate(events_list):
        for event in group:
            group_of[event['Event ID']] = index

    edges_list = [[] for _ in events_list]
    for edge in plots_edges:
        old_ID1, old_ID2 = edge[0], edge[1]
        if old_ID1 not in ID_map or old_ID2 not in ID_map:
            print(f"    [Warning] build_diagram(): edge {old_ID1} -> {old_ID2} references "
                  f"an event that was not renumbered; skipped.")
            continue
        new_ID1, new_ID2 = ID_map[old_ID1], ID_map[old_ID2]
        index = group_of.get(new_ID2)
        if index is None:
            continue
        edge_new = list(copy.deepcopy(edge))  # list() so tuple edges can be updated too
        edge_new[0], edge_new[1] = new_ID1, new_ID2
        # The details text may still mention the old IDs; keep it consistent with the new ones.
        edge_new[3] = _rename_endpoint_IDs(edge_new[3], {old_ID1: new_ID1, old_ID2: new_ID2})
        edges_list[index].append(edge_new)
    return edges_list


def _render_history(events_list, edges_list, mode):
    """Render the events and per-group edges as the `history` text returned by build_diagram."""
    parts = ["EVENTS LIST: [\n"]
    for group in events_list:
        parts.extend(f"    {event},\n" for event in group)
        parts.append("\n")
    parts.append("]\n\n")
    parts.append(f"{MODE[mode]} STORYLINE:\n")
    for index, group_edges in enumerate(edges_list, start=1):
        parts.append(f"- {MODE_[mode]} {index}: [\n")
        for idx, edge in enumerate(group_edges, start=1):
            parts.append(
                f"    - Event Relationship {idx}:\n"
                f"        - Events: {edge[0]} -> {edge[1]}\n"
                f"        - Strength: {edge[2]}\n"
                f"        - Details: {edge[3]}\n"
            )
        parts.append("]\n")
    return "".join(parts)


def build_diagram(plot_lines, plots_edges, mode="Topo"):
    """
    Renumber events with the chosen grouping and render them, with their incoming edges, as text.

    Args:
        plot_lines: list of acts, each a dict whose 'Plot Chains List' holds event dicts
            with an 'Event ID'.
        plots_edges: edges indexable as [source ID, target ID, strength, details].
        mode: one of the following:
            - "Topo": stages by topological depth.
            - "DFS": plotlines from a DFS path decomposition.
            - "chapter": one group per act.

    Returns:
        A tuple (event_ID_dict, events_list, edges_list, history):
            - event_ID_dict maps old IDs to new "Event N" IDs.
            - events_list is a list of groups of renumbered events.
            - edges_list[i] holds the renumbered edges whose target lies in group i.
            - history is the rendered text.

    Raises:
        ValueError: if `mode` is not a key of MODE.
    """
    global event_index, event_ID_dict
    if mode not in MODE:
        raise ValueError(f"Unknown mode {mode!r}; expected one of {list(MODE)}")
    # Restart numbering so every diagram begins at "Event 1".
    event_index = 0
    event_ID_dict = {}

    if mode == "Topo":
        events_list = topo_traverse(plot_lines, plots_edges)
    elif mode == "DFS":
        events_list = dfs_traverse(plot_lines, plots_edges)
    else:  # mode == "chapter"
        events_list = _group_events_by_chapter(plot_lines)

    # Traversals yield None for IDs they could not resolve; drop those placeholders.
    events_list = [[event for event in group if event is not None] for group in events_list]

    edges_list = _remap_edges(events_list, plots_edges, event_ID_dict)
    history = _render_history(events_list, edges_list, mode)
    return event_ID_dict, events_list, edges_list, history